# Reservoir Sampling

Resources
- https://www.youtube.com/watch?v=Ybra0uGEkpM
- https://en.wikipedia.org/wiki/Reservoir_sampling
- https://stats.stackexchange.com/a/488233

**Goal:** sample from a stream of size $N$ with or without replacement.

**Context:**
- $N$ is not known a priori
- you cannot collect all items beforehand, because $N$ can be prohibitively large
- you are only presented with an item $x_i$ at a given moment in time ($i$-th step), with $1 \leq i \leq N$
- you can only make a decision (update some _internal state_) based on the current item $x_i$

**Guarantee:** at the end of the algorithm, you obtain an unbiased sample, equivalent to what you would get from, e.g.:
```python
# sample w/o replacement
np.random.choice(xs, size=k, replace=False)

# sample w/ replacement
np.random.choice(xs, size=k, replace=True)

# where `xs` would be the entire contents of the stream
```

In [None]:
import numpy as np
from itertools import count
from collections import Counter
from tqdm.auto import tqdm, trange
import matplotlib.pyplot as plt
import math

In [None]:
def coin(p):
    """Flip a biased coin: return 1 with probability ``p``, otherwise 0."""
    # A Bernoulli(p) trial is a binomial draw with a single trial; the
    # one-element result array is unpacked and converted to a plain int.
    (flip,) = np.random.binomial(n=1, p=p, size=1)
    return int(flip)


def make_stream(num_items):
    """Lazily yield the integers ``0 .. num_items - 1`` in random order.

    Each value appears exactly once. Being a generator, the permutation is
    only drawn when the first item is requested.
    """
    order = np.arange(num_items)
    np.random.shuffle(order)  # shuffles in place
    yield from order

## `k == 1`

In [None]:
def rs_once(stream, stop_at):
    """Reservoir-sample a single item from the head of ``stream``.

    Items are pulled one at a time with ``next``; the i-th item (1-based)
    replaces the currently held item with probability ``1 / i``.

    Args:
        stream: an iterator of items.
        stop_at: how far into the stream to read; steps run over
            ``range(1, math.ceil(stop_at + 1))``, so it may be fractional.

    Returns:
        The item held in the reservoir after the last step.
    """
    reservoir = None
    upper = math.ceil(stop_at + 1)

    for position in range(1, upper):
        candidate = next(stream)
        # Keep the newcomer with probability 1/position, else keep what we hold.
        reservoir = candidate if coin(1 / position) else reservoir

    # Fails when no step ran (or the held item is literally None).
    assert reservoir is not None
    return reservoir

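At the end of the experiment, each item will have been sampled with probability $\dfrac{1}{\texttt{num_items}}$. Because the stream is reshuffled for every run, this holds even when only a fraction `f` of the stream is read.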

In [None]:
# Empirical check of rs_once: for several fractions f of the stream, repeat
# the sampling many times and plot a histogram of the sampled values.

# unknown to the user
num_items = 100

# experimental configs
sample_size = 10_000
fs = np.linspace(0.1, 1.0, endpoint=True, num=10)

# One histogram per fraction on a grid of up to 5 columns; rounding the row
# count up keeps the grid valid when len(fs) is not a multiple of 5.
num_cols = 5
num_rows = math.ceil(len(fs) / num_cols)

plt.figure(figsize=(16, 6))

for i, f in enumerate(fs, start=1):
    desc = f"{f=:.2f} ({i}/{len(fs)})"

    # linspace values carry float noise (e.g. 0.30000000000000004), so
    # f * num_items can land just above an integer and rs_once, which takes
    # a ceiling, would read one item too many. Round to the intended count.
    stop_at = int(round(f * num_items))

    xs = [
        rs_once(make_stream(num_items), stop_at=stop_at)
        for _ in trange(sample_size, leave=False, desc=desc)
    ]

    plt.subplot(num_rows, num_cols, i)
    plt.hist(xs, weights=None, rwidth=0.75, align="mid")
    plt.title(desc)

plt.tight_layout()
# No-op as the cell's final statement (presumably so the notebook echoes nothing).
pass

## `k > 1`

### w/ replacement

$\implies$ run $k$ independent single-item ($k = 1$) samplers in parallel over the same stream.

In [None]:
from itertools import product

# Sanity check: when drawing twice with replacement from n items, how often
# does a particular item (here: item 1) show up in at least one of the draws?
n = 20

x = set(range(n))
f = t = 0  # f: pairs containing item 1, t: all ordered pairs
for i, j in product(x, repeat=2):
    f += int(i == 1 or j == 1)
    t += 1

print(t)

# Three values that must agree:
#   - the brute-force frequency,
#   - the direct count (only i is 1) + (only j is 1) + (both are 1),
#   - the complement n^2 - (n - 1)^2, expanded.
f / t, ((n - 1) + (n - 1) + 1) / n**2, (n**2 - n**2 + 2 * n - 1) / n**2

In [None]:
def rs_many_with_replacement(stream, k, stop_at):
    """Draw ``k`` samples with replacement from the head of ``stream``.

    Runs ``k`` independent single-item reservoirs side by side in one pass:
    when the i-th item (1-based) arrives, each reservoir independently
    replaces its content with that item with probability ``1 / i``. Only the
    items actually needed are pulled, so the stream is never held in memory.

    Args:
        stream: an iterable of items; an iterator is advanced in place.
        k: number of samples to return.
        stop_at: how far into the stream to read; steps run over
            ``range(1, math.ceil(stop_at + 1))``, so it may be fractional.

    Returns:
        A list with ``k`` sampled items (the same item may appear repeatedly).

    Raises:
        StopIteration: if the stream runs out before the last step.
        AssertionError: if ``k > 0`` and some reservoir is still ``None``
            at the end (e.g. no step was executed).
    """
    stream = iter(stream)  # also accept plain iterables such as lists
    cells = [None] * k

    for i in range(1, math.ceil(stop_at + 1)):
        xi = next(stream)

        # One independent Bernoulli(1/i) coin flip per reservoir.
        hits = np.random.binomial(n=1, p=1 / i, size=k)
        for slot in np.flatnonzero(hits):
            cells[slot] = xi

    assert all(cell is not None for cell in cells)
    return cells

In [None]:
# Demo: build a stream via make_stream(num_items=100) and draw k=5 samples
# with replacement, passing stop_at=50 as the read limit.
stream = make_stream(num_items=100)
rs_many_with_replacement(stream, k=5, stop_at=50)

### w/o replacement

In [None]:
def rs_many_without_replacement(stream, k):
    cells = [None for _ in range(k)]

    for i in range(k):
        cells[i] = next(stream)

    while True:
        j = np.random.randint(1, i + 1, size=1).item()
        if j <= k:
            cells[j - 1] = next(stream)