In [None]:
import scipy.stats as sp
import numpy as np

In [None]:
# Standard normal N(0, 1); `sp.norm` is the same frozen-able distribution object
# as `sp.distributions.norm`.
normal = sp.norm(loc=0, scale=1)

In [None]:
# Draw 10 samples from N(0, 1); the seed makes the displayed output reproducible.
samples = normal.rvs(size=10, random_state=0)
samples

In [None]:
# Helpers for discretizing light-tailed distributions
from typing import Tuple
# Large finite bounds used by `approx` (outermost bin edges) and `find_min_max`.
# `1e4` is a float literal, so the annotation is `float`, not `int`.
SMALL_VALUE: float = -1e4
LARGE_VALUE: float = 1e4

def sign(x):
    """Return 1 when ``x >= 0`` and -1 otherwise (so zero maps to 1 and NaN to -1)."""
    if x >= 0:
        return 1
    return -1

def interpolate(x_min, x_max, n) -> list:
    """Return ``n`` evenly spaced points from ``x_min`` to ``x_max`` (both inclusive) as a Python list."""
    grid = np.linspace(start=x_min, stop=x_max, num=n)
    return grid.tolist()

def find_min_max(distr, thresh: float) -> Tuple[float, float]:
    bd = (LARGE_VALUE + SMALL_VALUE) / 2
    found_min = False
    found_max = False
    x_min = None
    x_max = None

    while not (found_min):

        print(bd)
        if abs(distr(bd) - thresh) < 1e-5: 
            x_min = bd
            found_min = True
        else:
            # too far away -> adjust
            bd += (distr(bd) - thresh) * (bd + SMALL_VALUE)
        
    # reset, start search again for max val
    bd = (LARGE_VALUE + SMALL_VALUE) / 2
    while not (found_max):
        print(bd)
        if abs(distr(bd) - (1 - thresh)) < 1e-5: 
            x_max = bd
            found_max = True
        else:
            bd -= (distr(bd) - (1-thresh)) * (bd + LARGE_VALUE)
    return x_min, x_max
    

def approx(n: int, distr: callable, x_min: float = -20, x_max: float = 20) -> list:
    """Discretize a distribution, given by its CDF, into ``n + 1`` bin probabilities.

    The bin edges are ``SMALL_VALUE``, then ``n`` evenly spaced points on
    ``[x_min, x_max]``, then ``LARGE_VALUE``. Bin ``i`` receives
    ``distr(right_i) - distr(left_i)``, so the probabilities sum to
    ``distr(LARGE_VALUE) - distr(SMALL_VALUE)``.

    Args:
        n: number of interior grid points.
        distr: CDF callable evaluated on scalar edges.
        x_min, x_max: range of the interior grid. The defaults reproduce the
            previous hard-coded ``[-20, 20]``.

    Returns:
        List of ``n + 1`` bin probabilities, ordered left to right.
    """
    # TODO: pick x_min / x_max from the tail quantiles once find_min_max is reliable.
    edges = [SMALL_VALUE] + interpolate(x_min, x_max, n) + [LARGE_VALUE]
    # Consecutive (left, right) edge pairs form the bins.
    return [distr(right) - distr(left) for left, right in zip(edges[:-1], edges[1:])]




In [None]:
# Discretize N(0, 1) onto 100 grid points in [-20, 20] (plus the two outer bins)
# and sanity-check the result: total mass should be ~1 and every bin non-negative.
probs_approx = approx(100, normal.cdf)
total_mass = np.sum(probs_approx)
# Count the bins directly from the boolean mask. np.where(mask) would return
# index arrays, and counting their nonzero entries skips index 0.
n_nonnegative = np.count_nonzero(np.asarray(probs_approx) >= 0)
total_mass, n_nonnegative, len(probs_approx)

# In principle this approach works, but `find_min_max` still needs a reliable way to locate the tail quantiles (for now `approx` uses a fixed range of [-20, 20]).

# The following implements the projection for the case where only the CDF of the reward distribution is available.

In [None]:
from typing import TypeAlias
# A discrete distribution as (atoms, probs): atom locations and the probability
# mass on each atom. algo_cdf and interpolate_atoms assume atoms are sorted increasingly.
Distribution: TypeAlias = Tuple[np.ndarray, np.ndarray]


from enum import Enum
class Direction(Enum):
    """Side of the current support on which `extend_support` adds new points."""

    LEFT = 0   # below the smallest atom (v_min)
    RIGHT = 1  # above the largest atom (v_max)


class ProbMesh:
    """Discrete distribution (atoms, probs) together with its mesh size.

    The mesh size ``sup_dist`` is the largest distance between neighbouring
    atoms. It is computed once, at construction.
    """

    def __init__(self, atoms: np.ndarray, probs: np.ndarray):
        self.atoms = np.asarray(atoms)
        self.probs = np.asarray(probs)
        # Stored under `sup_dist` (not `_sup_dist`). Assigning to `_sup_dist`
        # shadowed the method of that name and left `get_sup_dist` without an attribute.
        self.sup_dist = self._sup_dist()

    def _sort(self):
        """Sort atoms in increasing order, permuting probs alongside."""
        idcs: np.ndarray = np.argsort(self.atoms)
        self.atoms = self.atoms[idcs]
        self.probs = self.probs[idcs]

    def _sup_dist(self) -> float:
        """Largest gap between neighbouring atoms; 0.0 when there are fewer than two atoms.

        A sorted copy is used, so the result does not depend on the input order
        and `self.atoms` is not modified.
        """
        if self.atoms.size < 2:
            return 0.0
        return float(np.max(np.diff(np.sort(self.atoms))))

    def get_sup_dist(self) -> float:
        """Return the mesh size computed in ``__init__``."""
        return self.sup_dist


def algo_cdf(prior_distr: Distribution, k: int) -> Tuple[np.ndarray, ...]:
    """Decide where to add atoms to the current discrete approximation at iteration ``k``.

    Assumes the atoms ``prior_distr[0]`` are sorted in increasing order.

    Three kinds of new points are proposed:
      * left of the support, if the mass on the smallest atom exceeds ``min_thresh``;
      * right of the support, if the mass on the largest atom exceeds ``min_thresh``;
      * between existing atoms (via ``interpolate_atoms``), if some CDF jump
        ``probs[1:]`` exceeds ``inter_thresh``.

    Args:
        prior_distr: ``(atoms, probs)`` of the current approximation.
        k: iteration index. The thresholds shrink like ``exp(-k)``, and the
            helpers add ``k // 2`` points per extension.

    Returns:
        ``(left_extend, mid_extend, right_extend, atoms)``. Extensions that are
        not triggered are empty arrays.
    """
    atoms, probs = prior_distr[0], prior_distr[1]

    # Thresholds decay with k, so the support keeps growing until the outer atoms
    # carry negligible mass.
    min_thresh = 5 * np.exp(-(k + 5))
    # TODO: possibly use a fixed atom budget instead, e.g. 2^(k+1) atoms in iteration k.
    inter_thresh = min_thresh

    v_min, v_max = atoms[0], atoms[-1]
    left_extend = np.array([])
    right_extend = np.array([])
    mid_extend = np.array([])

    if probs[0] > min_thresh:
        left_extend = extend_support(v_min, v_max, k, Direction.LEFT)

    # Mirror the left-hand test on the largest atom's mass. The former
    # `1 - probs[-1] > min_thresh` was true for practically every input.
    if probs[-1] > min_thresh:
        right_extend = extend_support(v_min, v_max, k, Direction.RIGHT)

    # np.diff(np.cumsum(probs)) equals probs[1:]. Skip this for a single atom,
    # where there is nothing to interpolate and np.max would fail on an empty array.
    if len(probs) > 1 and np.max(np.diff(np.cumsum(probs), 1)) > inter_thresh:
        mid_extend = interpolate_atoms(prior_distr, k)

    return left_extend, mid_extend, right_extend, atoms

def extend_support(v_min: float, v_max: float, k: int, direction: Direction) -> np.ndarray:
    """Propose ``k // 2`` new points outside ``[v_min, v_max]`` on one side.

    Points are spaced ``step = (v_max - v_min) * k`` apart, starting one step
    beyond the chosen endpoint. Both directions return the points in increasing
    order, so they can be concatenated directly with sorted atoms.

    Args:
        v_min, v_max: current smallest and largest atom.
        k: iteration index (k >= 2 yields at least one point).
        direction: ``Direction.LEFT`` (below ``v_min``) or ``Direction.RIGHT``
            (above ``v_max``).

    Returns:
        1-D array of ``k // 2`` points (empty when ``k < 2``).

    Raises:
        ValueError: if ``direction`` is neither ``Direction.LEFT`` nor ``Direction.RIGHT``.
    """
    n_new: int = k // 2
    step_size = (v_max - v_min) * k
    offsets = step_size * np.arange(1, n_new + 1)

    if direction == Direction.LEFT:
        # v_min - n*step, ..., v_min - step. Reversed so the output is increasing.
        return v_min - offsets[::-1]
    if direction == Direction.RIGHT:
        return v_max + offsets
    # Previously fell through to an UnboundLocalError on `new_points`.
    raise ValueError(f"direction must be Direction.LEFT or Direction.RIGHT, got {direction!r}")


def interpolate_atoms(prior_distr: Distribution, k: int) -> np.ndarray:
    """Propose up to ``k // 2`` new points halfway between neighbouring atoms.

    Gap ``j`` (between atoms ``j`` and ``j + 1``) is ranked by the CDF jump
    ``cumsum(probs)[j + 1] - cumsum(probs)[j]``. The midpoints of the
    highest-ranked gaps are returned, largest jump first, so the mesh is refined
    where the CDF changes most.

    Assumes atoms are sorted in increasing order and ``k >= 2``.
    """
    atoms, probs = prior_distr[0], prior_distr[1]
    budget: int = k // 2

    cdf_jumps = np.diff(np.cumsum(probs), 1)
    # Gap indices ordered by descending CDF jump, truncated to the budget.
    chosen = np.argsort(cdf_jumps)[::-1][:budget]
    return (atoms[chosen + 1] + atoms[chosen]) / 2



def project_cdf(distr_cdf: callable, param: Tuple[np.ndarray, ...]) -> Distribution:
    """Project a distribution, known only through its CDF ``F``, onto a new set of atoms.

    The new atoms are the union of ``left``, ``mid``, ``prior_atoms`` and
    ``right``, sorted in increasing order. Cells are delimited by the midpoints
    ``m_i`` between neighbouring atoms, and each atom receives the mass of its cell:

        p_0 = F(m_0),   p_i = F(m_i) - F(m_{i-1}),   p_last = 1 - F(m_last)

    so there is exactly one probability per atom, the outer atoms carry the tail
    mass, and the probabilities sum to 1 up to rounding. (The previous version
    did not sort the atoms, returned the midpoints as atoms with one probability
    fewer than atoms, and patched the missing tail mass onto the last point.)

    Args:
        distr_cdf: vectorised CDF accepting a NumPy array.
        param: ``(left, mid, right, prior_atoms)`` as produced by ``algo_cdf``.

    Returns:
        ``(atoms, probs)`` with ``len(atoms) == len(probs)``.

    Raises:
        ValueError: if ``param`` contains no atoms at all.
    """
    left, mid, right, prior_atoms = param
    # Sort the full support. Extension points may arrive in any order, and
    # interpolated points interleave with the prior atoms.
    new_atoms = np.sort(np.concatenate([left, mid, prior_atoms, right]).astype(float))
    if new_atoms.size == 0:
        raise ValueError("project_cdf needs at least one atom")

    cell_edges = (new_atoms[1:] + new_atoms[:-1]) / 2
    if cell_edges.size:
        cdf_at_edges = np.asarray(distr_cdf(cell_edges), dtype=float)
    else:
        cdf_at_edges = np.empty(0)
    # Pad with F(-inf) = 0 and F(+inf) = 1 so the outer cells absorb the tails.
    new_probs = np.diff(np.concatenate([[0.0], cdf_at_edges, [1.0]]))

    assert np.isclose(np.sum(new_probs), 1)
    return new_atoms, new_probs

In [None]:
# Scratch: midpoints between sorted atoms, and the "first value followed by
# consecutive differences" decomposition (seeded so the output is reproducible).
rng = np.random.default_rng(0)
demo_atoms = np.sort(rng.integers(-10, 10, 5))
demo_midpoints = (demo_atoms[1:] + demo_atoms[:-1]) / 2
demo_increments = np.concatenate([demo_atoms[0:1], np.ediff1d(demo_atoms)])
demo_atoms, demo_midpoints, demo_increments

In [None]:
# Scratch: np.concatenate joins 1-D arrays end to end (seeded for reproducibility).
rng = np.random.default_rng(1)
a = rng.random(2)
b = rng.random(2)
np.concatenate([a, b])

In [None]:
# Toy discrete distribution: 5 sorted integer atoms in [0, 20) with random
# weights normalized to sum to 1 (seeded for reproducibility).
rng = np.random.default_rng(2)
atoms = np.sort(rng.integers(0, 20, 5))
probs = rng.random(5)
probs = probs / np.sum(probs)
list(zip(atoms, probs))

In [None]:
# Midpoints proposed for k = 2, i.e. k // 2 = 1 new point.
d = atoms, probs
interpolate_atoms(prior_distr=d, k=2)

In [None]:
# Points proposed to the left of the toy support for k = 2.
extend_support(v_min=atoms[0], v_max=atoms[-1], k=2, direction=Direction.LEFT)


In [None]:

# Scratch: reproduce the gap ranking used by interpolate_atoms step by step.
# Atoms are sorted because algo_cdf / interpolate_atoms assume increasing atoms.
rng = np.random.default_rng(3)
b = rng.random(5)
atoms = np.sort(rng.integers(0, 10, size=5))
probs = b / np.sum(b)

cs = np.cumsum(probs)
d_of_cs = np.diff(cs)              # equals probs[1:]
order = np.argsort(d_of_cs)[::-1]  # gap indices, largest CDF jump first
cs, d_of_cs, order






In [None]:
# Inspect the current toy atoms.
atoms

In [None]:
# Probabilities at the atoms indexed by `order`.
# NOTE(review): the CDF jump of gap j is probs[j + 1], not probs[j].
np.take(probs, order)

In [None]:
# Gap indices ranked by descending CDF jump.
order 

In [None]:
# Re-display the atoms for comparison with the neighbour-indexing cells below.
atoms

In [None]:
# Width of each ranked gap. Gap j lies between atoms j and j + 1, because
# d_of_cs[j] = cs[j + 1] - cs[j]. `order - 1` would instead wrap to atoms[-1]
# whenever gap 0 is ranked. (`.transpose()` was a no-op on this 1-D array.)
atoms[order + 1] - atoms[order]

In [None]:
# Left neighbour of each ranked gap.
atoms[order]

In [None]:
# Right neighbour of each ranked gap (gap j spans atoms j and j + 1).
# `order - 1` picked the atom *before* the gap and wrapped to atoms[-1] for gap 0.
atoms[order + 1]

In [None]:

# One refinement proposal on the toy distribution (algo_cdf assumes sorted atoms).
algo_cdf(prior_distr=(atoms, probs), k=2)

In [None]:
# Refine a two-atom starting guess towards the target CDF for k = 2, ..., 19.
# The target is built here so this cell does not depend on `cauchy`, which is
# only defined further down the notebook (that broke Restart & Run All).
target_distr = sp.cauchy()  # alternative: sp.norm(loc=0, scale=16)
approx_dist = (np.array([-10, 10]), np.array([.5, .5]))
for k in range(2, 20):
    param = algo_cdf(approx_dist, k)
    approx_dist = project_cdf(target_distr.cdf, param)


In [None]:
# Final (atoms, probs) approximation produced by the refinement loop.
approx_dist

In [None]:
# (number of probabilities, number of atoms); the two should match.
tuple(len(component) for component in reversed(approx_dist))


In [None]:
import scipy.stats as sp

In [None]:
# Wide normal N(0, 16^2); an alternative target to the Cauchy distribution.
norm = sp.norm(scale=16, loc=0)

In [None]:
# Standard Cauchy distribution (heavy-tailed target).
cauchy = sp.cauchy(loc=0, scale=1)

In [None]:
import matplotlib.pyplot as plt

In [None]:
# Dense evaluation grid for plotting the true CDF.
lins = np.linspace(start=-100, stop=100, num=10000)

In [None]:
# Compare the true CDF with the cumulative probabilities of the discrete approximation.
distr = cauchy  # or: distr = norm
fig, ax = plt.subplots()
ax.plot(lins, distr.cdf(lins), label="true CDF")
ax.plot(approx_dist[0], np.cumsum(approx_dist[1]), label="approximation (cumulative probs)")
ax.set(xlim=(-20, 20), xlabel="x", ylabel="F(x)", title="True CDF vs. discrete approximation")
ax.legend();